{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-strategy workflow with reflection\n",
    "\n",
    "In this notebook we'll demonstrate a workflow that attempts 3 different query strategies in parallel and picks the best one.\n",
    "\n",
    "As shown in the diagram below:\n",
    "* First the quality of the query is judged. If it's a bad query, a `BadQueryEvent` is emitted and the `improve_query` step will try to improve the quality of the query before trying again. This is reflection.\n",
    "* Once an acceptable query has been found, three simultaneous events are emitted: a `NaiveRAGEvent`, a `HighTopKEvent`, and a `RerankEvent`.\n",
    "* Each of these events is picked up by a dedicated step that tries a different RAG strategy on the same index. All 3 emit a `ResponseEvent`.\n",
    "* The `judge` step waits until it has collected all three `ResponseEvent`s, then compares them and emits the best response as a `StopEvent`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Screenshot 2024-08-16 at 12.41.23 PM.png]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies\n",
    "\n",
    "We need LlamaIndex, the file reader (for reading PDFs), the workflow visualizer (to draw the diagram above), and OpenAI to embed the data and query an LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index-core llama-index-llms-openai llama-index-utils-workflow llama-index-readers-file llama-index-embeddings-openai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the data\n",
    "\n",
    "We are using 3 long PDFs of San Francisco's annual budget from 2016 through 2018."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!wget \"https://www.dropbox.com/scl/fi/xt3squt47djba0j7emmjb/2016-CSF_Budget_Book_2016_FINAL_WEB_with-cover-page.pdf?rlkey=xs064cjs8cb4wma6t5pw2u2bl&dl=0\" -O \"data/2016-CSF_Budget_Book_2016_FINAL_WEB_with-cover-page.pdf\"\n",
    "!wget \"https://www.dropbox.com/scl/fi/jvw59g5nscu1m7f96tjre/2017-Proposed-Budget-FY2017-18-FY2018-19_1.pdf?rlkey=v988oigs2whtcy87ti9wti6od&dl=0\" -O \"data/2017-Proposed-Budget-FY2017-18-FY2018-19_1.pdf\"\n",
    "!wget \"https://www.dropbox.com/scl/fi/izknlwmbs7ia0lbn7zzyx/2018-o0181-18.pdf?rlkey=p5nv2ehtp7272ege3m9diqhei&dl=0\" -O \"data/2018-o0181-18.pdf\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bring in dependencies\n",
    "\n",
    "Now we import all of our dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from llama_index.core import (\n",
    "    SimpleDirectoryReader,\n",
    "    VectorStoreIndex,\n",
    "    StorageContext,\n",
    "    load_index_from_storage,\n",
    ")\n",
    "from llama_index.core.workflow import (\n",
    "    step,\n",
    "    Context,\n",
    "    Workflow,\n",
    "    Event,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    ")\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core.postprocessor.rankGPT_rerank import RankGPTRerank\n",
    "from llama_index.core.query_engine import RetrieverQueryEngine\n",
    "from llama_index.core.chat_engine import SimpleChatEngine\n",
    "from llama_index.utils.workflow import draw_all_possible_flows"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also need to set up our OpenAI key. This notebook pulls it from Colab's secrets manager; if you're running elsewhere, set the `OPENAI_API_KEY` environment variable directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.colab import userdata\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = userdata.get(\"openai-key\")"
   ]
  },
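  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're not on Colab, the `google.colab` import above won't work; a minimal alternative is to read the key from your shell environment (a sketch: the `get_openai_key` helper is hypothetical, and it assumes you've already exported `OPENAI_API_KEY`):\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def get_openai_key() -> str:\n",
    "    # Read the key from the environment rather than Colab's secrets manager\n",
    "    key = os.environ.get(\"OPENAI_API_KEY\", \"\")\n",
    "    if not key:\n",
    "        raise RuntimeError(\"Set OPENAI_API_KEY before running this notebook\")\n",
    "    return key\n",
    "```"
   ]
  },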
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define event classes\n",
    "\n",
    "Our flow generates quite a few different event types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JudgeEvent(Event):\n",
    "    query: str\n",
    "\n",
    "\n",
    "class BadQueryEvent(Event):\n",
    "    query: str\n",
    "\n",
    "\n",
    "class NaiveRAGEvent(Event):\n",
    "    query: str\n",
    "\n",
    "\n",
    "class HighTopKEvent(Event):\n",
    "    query: str\n",
    "\n",
    "\n",
    "class RerankEvent(Event):\n",
    "    query: str\n",
    "\n",
    "\n",
    "class ResponseEvent(Event):\n",
    "    query: str\n",
    "    source: str\n",
    "    response: str\n",
    "\n",
    "\n",
    "class SummarizeEvent(Event):\n",
    "    query: str\n",
    "    response: str"
   ]
  },
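  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each event class above is just a small typed payload; the workflow delivers an event instance to whichever step declares that event type in its signature. A framework-free sketch of the route-by-type idea, using plain dataclasses (hypothetical names, not the actual LlamaIndex machinery):\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class BadQueryEvent:\n",
    "    query: str\n",
    "\n",
    "@dataclass\n",
    "class NaiveRAGEvent:\n",
    "    query: str\n",
    "\n",
    "def improve_query(ev: BadQueryEvent) -> str:\n",
    "    return f\"improved: {ev.query}\"\n",
    "\n",
    "def naive_rag(ev: NaiveRAGEvent) -> str:\n",
    "    return f\"answered: {ev.query}\"\n",
    "\n",
    "# Route by concrete event type, the way @step matches on annotations\n",
    "HANDLERS = {BadQueryEvent: improve_query, NaiveRAGEvent: naive_rag}\n",
    "\n",
    "def dispatch(ev):\n",
    "    return HANDLERS[type(ev)](ev)\n",
    "```"
   ]
  },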
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define workflow\n",
    "\n",
    "This is the substance of our workflow, so let's break it down:\n",
    "\n",
    "* `load_or_create_index` is a normal RAG function that reads our PDFs from disk and indexes them if they aren't already indexed. If indexing already happened it will simply restore the existing index from disk.\n",
    "\n",
    "* `judge_query` does a few things\n",
    "  * It initializes the LLM and calls `load_or_create_index` to get set up. It stores these things in the context so they are available later.\n",
    "  * It judges the quality of the query\n",
    "  * If the query is bad it emits a `BadQueryEvent`\n",
    "  * If the query is good it emits a `NaiveRAGEvent`, a `HighTopKEvent`, and a `RerankEvent`\n",
    "\n",
    "* `improve_query` takes the `BadQueryEvent` and uses an LLM to try to expand and disambiguate the query, then it loops back to `judge_query`\n",
    "\n",
    "* `naive_rag`, `high_top_k` and `rerank` accept their respective events and attempt 3 different RAG strategies. Each emits a `ResponseEvent` with their result and a `source` parameter that says which strategy was used\n",
    "\n",
    "* `judge` fires every time a `ResponseEvent` is emitted, but it uses `collect_events` to buffer them until it has received all 3. Then it sends the responses to an LLM and asks it to select the \"best\" one. It emits the best response as a `StopEvent`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ComplicatedWorkflow(Workflow):\n",
    "    def load_or_create_index(self, directory_path, persist_dir):\n",
    "        # Check if the index already exists\n",
    "        if os.path.exists(persist_dir):\n",
    "            print(\"Loading existing index...\")\n",
    "            # Load the index from disk\n",
    "            storage_context = StorageContext.from_defaults(\n",
    "                persist_dir=persist_dir\n",
    "            )\n",
    "            index = load_index_from_storage(storage_context)\n",
    "        else:\n",
    "            print(\"Creating new index...\")\n",
    "            # Load documents from the specified directory\n",
    "            documents = SimpleDirectoryReader(directory_path).load_data()\n",
    "\n",
    "            # Create a new index from the documents\n",
    "            index = VectorStoreIndex.from_documents(documents)\n",
    "\n",
    "            # Persist the index to disk\n",
    "            index.storage_context.persist(persist_dir=persist_dir)\n",
    "\n",
    "        return index\n",
    "\n",
    "    @step\n",
    "    async def judge_query(\n",
    "        self, ctx: Context, ev: StartEvent | JudgeEvent\n",
    "    ) -> BadQueryEvent | NaiveRAGEvent | HighTopKEvent | RerankEvent:\n",
    "        # initialize\n",
    "        llm = await ctx.store.get(\"llm\", default=None)\n",
    "        if llm is None:\n",
    "            await ctx.store.set(\"llm\", OpenAI(model=\"gpt-4o\", temperature=0.1))\n",
    "            await ctx.store.set(\n",
    "                \"index\", self.load_or_create_index(\"data\", \"storage\")\n",
    "            )\n",
    "\n",
    "            # we use a chat engine so it remembers previous interactions\n",
    "            await ctx.store.set(\"judge\", SimpleChatEngine.from_defaults())\n",
    "\n",
    "        judge = await ctx.store.get(\"judge\")\n",
    "        response = judge.chat(\n",
    "            f\"\"\"\n",
    "            Given a user query, determine if this is likely to yield good results from a RAG system as-is. If it's good, return 'good', if it's bad, return 'bad'.\n",
    "            Good queries use a lot of relevant keywords and are detailed. Bad queries are vague or ambiguous.\n",
    "\n",
    "            Here is the query: {ev.query}\n",
    "            \"\"\"\n",
    "        )\n",
    "        if str(response).strip().lower() == \"bad\":\n",
    "            # try again\n",
    "            return BadQueryEvent(query=ev.query)\n",
    "        else:\n",
    "            # send query to all 3 strategies\n",
    "            ctx.send_event(NaiveRAGEvent(query=ev.query))\n",
    "            ctx.send_event(HighTopKEvent(query=ev.query))\n",
    "            ctx.send_event(RerankEvent(query=ev.query))\n",
    "\n",
    "    @step\n",
    "    async def improve_query(\n",
    "        self, ctx: Context, ev: BadQueryEvent\n",
    "    ) -> JudgeEvent:\n",
    "        llm = await ctx.store.get(\"llm\")\n",
    "        response = llm.complete(\n",
    "            f\"\"\"\n",
    "            This is a query to a RAG system: {ev.query}\n",
    "\n",
    "            The query is bad because it is too vague. Please provide a more detailed query that includes specific keywords and removes any ambiguity.\n",
    "        \"\"\"\n",
    "        )\n",
    "        return JudgeEvent(query=str(response))\n",
    "\n",
    "    @step\n",
    "    async def naive_rag(\n",
    "        self, ctx: Context, ev: NaiveRAGEvent\n",
    "    ) -> ResponseEvent:\n",
    "        index = await ctx.store.get(\"index\")\n",
    "        engine = index.as_query_engine(similarity_top_k=5)\n",
    "        response = engine.query(ev.query)\n",
    "        print(\"Naive response:\", response)\n",
    "        return ResponseEvent(\n",
    "            query=ev.query, source=\"Naive\", response=str(response)\n",
    "        )\n",
    "\n",
    "    @step\n",
    "    async def high_top_k(\n",
    "        self, ctx: Context, ev: HighTopKEvent\n",
    "    ) -> ResponseEvent:\n",
    "        index = await ctx.store.get(\"index\")\n",
    "        engine = index.as_query_engine(similarity_top_k=20)\n",
    "        response = engine.query(ev.query)\n",
    "        print(\"High top k response:\", response)\n",
    "        return ResponseEvent(\n",
    "            query=ev.query, source=\"High top k\", response=str(response)\n",
    "        )\n",
    "\n",
    "    @step\n",
    "    async def rerank(self, ctx: Context, ev: RerankEvent) -> ResponseEvent:\n",
    "        index = await ctx.store.get(\"index\")\n",
    "        reranker = RankGPTRerank(top_n=5, llm=await ctx.store.get(\"llm\"))\n",
    "        retriever = index.as_retriever(similarity_top_k=20)\n",
    "        engine = RetrieverQueryEngine.from_args(\n",
    "            retriever=retriever,\n",
    "            node_postprocessors=[reranker],\n",
    "        )\n",
    "        response = engine.query(ev.query)\n",
    "        print(\"Reranker response:\", response)\n",
    "        return ResponseEvent(\n",
    "            query=ev.query, source=\"Reranker\", response=str(response)\n",
    "        )\n",
    "\n",
    "    @step\n",
    "    async def judge(self, ctx: Context, ev: ResponseEvent) -> StopEvent:\n",
    "        ready = ctx.collect_events(ev, [ResponseEvent] * 3)\n",
    "        if ready is None:\n",
    "            return None\n",
    "\n",
    "        judge = await ctx.store.get(\"judge\")\n",
    "        response = judge.chat(\n",
    "            f\"\"\"\n",
    "            A user has provided a query and 3 different strategies have been used\n",
    "            to try to answer the query. Your job is to decide which strategy best\n",
    "            answered the query. The query was: {ev.query}\n",
    "\n",
    "            Response 1 ({ready[0].source}): {ready[0].response}\n",
    "            Response 2 ({ready[1].source}): {ready[1].response}\n",
    "            Response 3 ({ready[2].source}): {ready[2].response}\n",
    "\n",
    "            Please provide the number of the best response (1, 2, or 3).\n",
    "            Just provide the number, with no other text or preamble.\n",
    "        \"\"\"\n",
    "        )\n",
    "\n",
    "        best_response = int(str(response).strip())\n",
    "        print(\n",
    "            f\"Best response was number {best_response}, which was from {ready[best_response-1].source}\"\n",
    "        )\n",
    "        return StopEvent(result=str(ready[best_response - 1].response))"
   ]
  },
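  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The buffering behaviour of `collect_events` is the subtle part of `judge`: the step fires once per `ResponseEvent`, returns `None` until the full set has arrived, and only then proceeds. A framework-free sketch of that accumulate-then-fire pattern (the `EventBuffer` class is hypothetical, not the LlamaIndex internals):\n",
    "\n",
    "```python\n",
    "class EventBuffer:\n",
    "    \"\"\"Accumulate events until `expected` have arrived, then release them all.\"\"\"\n",
    "\n",
    "    def __init__(self, expected: int):\n",
    "        self.expected = expected\n",
    "        self.events = []\n",
    "\n",
    "    def collect(self, event):\n",
    "        # Like ctx.collect_events: returns None until the set is complete\n",
    "        self.events.append(event)\n",
    "        if len(self.events) < self.expected:\n",
    "            return None\n",
    "        ready, self.events = self.events, []\n",
    "        return ready\n",
    "```"
   ]
  },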
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Draw flow diagram\n",
    "\n",
    "This is how we get the diagram we showed at the start."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_all_possible_flows(\n",
    "    ComplicatedWorkflow, filename=\"complicated_workflow.html\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the workflow\n",
    "\n",
    "Let's take the workflow for a spin:\n",
    "* The `judge_query` step produced no event directly: the query was judged \"good\", so it fanned the query out to all three strategies with `send_event`.\n",
    "* All 3 RAG steps run and generate different answers to the query\n",
    "* The `judge` step runs 3 times. The first 2 times it produces no event, because it has not collected the requisite 3 `ResponseEvent`s.\n",
    "* On the third time it selects the best response and returns a `StopEvent`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running step judge_query\n",
      "Creating new index...\n",
      "Step judge_query produced no event\n",
      "Running step naive_rag\n",
      "Naive response: Spending has increased over the years due to various factors such as new voter-approved minimum spending requirements, the creation of new voter-approved baselines, and growth in baseline funded requirements. Additionally, there have been notable changes in spending across different service areas and departments, with increases in funding for areas like public protection, transportation, and public works.\n",
      "Step naive_rag produced event ResponseEvent\n",
      "Running step rerank\n",
      "Reranker response: Spending has increased over the years, with notable changes in the allocation of funds to various service areas and departments. The budget reflects adjustments in spending to address evolving needs and priorities, resulting in a rise in overall expenditures across different categories.\n",
      "Step rerank produced event ResponseEvent\n",
      "Running step high_top_k\n",
      "High top k response: Spending has increased over the years, with the total budget showing growth in various areas such as aid assistance/grants, materials & supplies, equipment, debt service, services of other departments, and professional & contractual services. Additionally, there have been new investments in programs like workforce development, economic development, film services, and finance and administration. The budget allocations have been adjusted to accommodate changing needs and priorities, reflecting an overall increase in spending across different departments and programs.\n",
      "Step high_top_k produced event ResponseEvent\n",
      "Running step judge\n",
      "Step judge produced no event\n",
      "Running step judge\n",
      "Step judge produced no event\n",
      "Running step judge\n",
      "Best response was number 3, which was from High top k\n",
      "Step judge produced event StopEvent\n",
      "Spending has increased over the years, with the total budget showing growth in various areas such as aid assistance/grants, materials & supplies, equipment, debt service, services of other departments, and professional & contractual services. Additionally, there have been new investments in programs like workforce development, economic development, film services, and finance and administration. The budget allocations have been adjusted to accommodate changing needs and priorities, reflecting an overall increase in spending across different departments and programs.\n"
     ]
    }
   ],
   "source": [
    "c = ComplicatedWorkflow(timeout=120, verbose=True)\n",
    "result = await c.run(\n",
    "    # query=\"How has spending on police changed in San Francisco's budgets from 2016 to 2018?\"\n",
    "    # query=\"How has spending on healthcare changed in San Francisco?\"\n",
    "    query=\"How has spending changed?\"\n",
    ")\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
